/* Copyright 2018 SiFive, Inc */
/* SPDX-License-Identifier: Apache-2.0 */

// Extended and modified for use in the Crypto Core project.

#include <unistd.h>
#include "main.h"
#include "secure/pmp.h"

#include "debugprintf.h"
// Section boundary symbols. Only their addresses are taken in this file;
// presumably they are provided by the linker script - NOTE(review): confirm.
extern volatile char __mmtext_end;
extern volatile char __mmtext_start;
extern volatile char __mmdata_end;
extern volatile char __mmdata_start;


// The MMTEXT, MMDATA and MMRODATA sections are protected via PMP.

// NOTE(review): GRANULAR is not referenced anywhere in this file.
#define GRANULAR 4
// Number of PMP entries that init() and enable() program.
#define PMP_REGIONS 8

namespace PMP {
	namespace {
		/*
		 * Encode a NAPOT (naturally aligned power-of-two) PMP region.
		 *
		 * A region of 2^k bytes (k >= 3) at byte address `base` is stored in
		 * pmpaddr as the word address followed by (k - 3) trailing one bits:
		 *     pmpaddr = (base >> 2) | ((length >> 3) - 1)
		 *
		 * @param base      byte address of the region, aligned to length64
		 * @param length64  region size in bytes: a power of two in [8, 2^33]
		 * @return pmpaddr << 2 (apply() shifts it back), or 0xffffffff if the
		 *         region cannot be NAPOT-encoded. Valid results always have
		 *         their two low bits clear, so they never equal the sentinel.
		 */
		uint32_t MMTEXT calcNAPOT(uint32_t base, uint64_t length64) {
			const uint32_t invalid = 0xffffffff;

			// NAPOT covers at least 8 bytes; a 4-byte region needs NA4 mode.
			// Sizes 4..7 would otherwise leave no size bit to clear below.
			if (length64 < 8)
				return invalid;

			// exactly one bit set - tested on the full 64-bit size so that
			// sizes >= 2^34 are not truncated into a smaller "valid" size
			if (length64 & (length64 - 1))
				return invalid;

			// largest region expressible with a 32-bit base and the shifted
			// 32-bit return value: pmpaddr 0x3fffffff = 2^33 bytes
			if (length64 > (1ull << 33))
				return invalid;

			// the base must be naturally aligned to the region size; checked on
			// the byte address so stray bits 0..1 are not silently dropped
			if (base & (length64 - 1))
				return invalid;

			// (length / 8) - 1 yields the k-3 trailing ones encoding the size
			const uint32_t sizeBits = static_cast<uint32_t>((length64 >> 3) - 1);

			return ((base >> 2) | sizeBits) << 2;	// will get shifted back in apply
		}

		/*
		 * Read the configuration and address of one PMP entry.
		 *
		 * @param region  entry index, 0 <= region < PMP_REGIONS
		 * @param config  receives the entry's 8-bit field of pmpcfg<region/4>
		 * @param address receives the raw pmpaddr<region> CSR value
		 * @return 0 on success, 1 if config or address is NULL,
		 *         2 if region is outside 0 .. PMP_REGIONS-1.
		 *         Neither output is written on failure.
		 */
		int MMTEXT getRegion(unsigned int region, PMPConfig *config, size_t *address)
		{
			size_t pmpcfg = 0;

			if(!config || !address) {
				/* NULL pointers are invalid arguments */
				return 1;
			}

			if(region >= PMP_REGIONS) {
				/* Region outside of supported range: valid indices are
				 * 0 .. PMP_REGIONS-1 */
				return 2;
			}

			/* Each pmpcfgN register holds four 8-bit entry configs (RV32
			 * layout). CSR numbers must be immediates, hence the switch.
			 * The asm is volatile so the compiler does not treat a CSR read
			 * as a pure computation it may delete, merge with an earlier
			 * read, or hoist across the CSR writes done in setRegion. */
			switch(region / 4) {
			case 0:
				asm volatile("csrr %[cfg], pmpcfg0"
						: [cfg] "=r" (pmpcfg) ::);
				break;
			case 1:
				asm volatile("csrr %[cfg], pmpcfg1"
						: [cfg] "=r" (pmpcfg) ::);
				break;
			case 2:
				asm volatile("csrr %[cfg], pmpcfg2"
						: [cfg] "=r" (pmpcfg) ::);
				break;
			case 3:
				asm volatile("csrr %[cfg], pmpcfg3"
						: [cfg] "=r" (pmpcfg) ::);
				break;
			}

			/* extract this entry's byte lane */
			pmpcfg = (0xFF & (pmpcfg >> (8 * (region % 4)) ) );
			*config = PMPConfig(pmpcfg);

			/* cases up to 15 are kept so PMP_REGIONS can be raised to 16 */
			switch(region) {
			case 0:
				asm volatile("csrr %[addr], pmpaddr0"
						: [addr] "=r" (*address) ::);
				break;
			case 1:
				asm volatile("csrr %[addr], pmpaddr1"
						: [addr] "=r" (*address) ::);
				break;
			case 2:
				asm volatile("csrr %[addr], pmpaddr2"
						: [addr] "=r" (*address) ::);
				break;
			case 3:
				asm volatile("csrr %[addr], pmpaddr3"
						: [addr] "=r" (*address) ::);
				break;
			case 4:
				asm volatile("csrr %[addr], pmpaddr4"
						: [addr] "=r" (*address) ::);
				break;
			case 5:
				asm volatile("csrr %[addr], pmpaddr5"
						: [addr] "=r" (*address) ::);
				break;
			case 6:
				asm volatile("csrr %[addr], pmpaddr6"
						: [addr] "=r" (*address) ::);
				break;
			case 7:
				asm volatile("csrr %[addr], pmpaddr7"
						: [addr] "=r" (*address) ::);
				break;
			case 8:
				asm volatile("csrr %[addr], pmpaddr8"
						: [addr] "=r" (*address) ::);
				break;
			case 9:
				asm volatile("csrr %[addr], pmpaddr9"
						: [addr] "=r" (*address) ::);
				break;
			case 10:
				asm volatile("csrr %[addr], pmpaddr10"
						: [addr] "=r" (*address) ::);
				break;
			case 11:
				asm volatile("csrr %[addr], pmpaddr11"
						: [addr] "=r" (*address) ::);
				break;
			case 12:
				asm volatile("csrr %[addr], pmpaddr12"
						: [addr] "=r" (*address) ::);
				break;
			case 13:
				asm volatile("csrr %[addr], pmpaddr13"
						: [addr] "=r" (*address) ::);
				break;
			case 14:
				asm volatile("csrr %[addr], pmpaddr14"
						: [addr] "=r" (*address) ::);
				break;
			case 15:
				asm volatile("csrr %[addr], pmpaddr15"
						: [addr] "=r" (*address) ::);
				break;
			}

			return 0;
		}


		/*
		 * Program the address and configuration of one PMP entry.
		 *
		 * Only registers whose value actually changes are written. The address
		 * goes first: once the new config sets the lock bit, the entry can no
		 * longer be modified.
		 *
		 * @param region  entry index, 0 <= region < PMP_REGIONS
		 * @param config  new configuration; only the low 8 bits of toInt() are used
		 * @param address raw value for pmpaddr<region>
		 * @return 0 on success, 2 if region is out of range, 4 if the entry is
		 *         already locked, or the error code of getRegion
		 */
		int MMTEXT setRegion(unsigned int region, PMPConfig config, size_t address)
		{
			PMPConfig old_config;
			size_t old_address;

			if(region >= PMP_REGIONS) {
				/* Region outside of supported range */
				return 2;
			}

			int rc = getRegion(region, &old_config, &old_address);
			if(rc) {
				/* Error reading region */
				return rc;
			}

			if(old_config.isLocked()) {
				/* Cannot modify locked region */
				return 4;
			}

			/* Update the address first, because if the region is being locked we won't
			 * be able to modify it after we set the config */
			if(old_address != address) {
				switch(region) {
				case 0:
					asm volatile("csrw pmpaddr0, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 1:
					asm volatile("csrw pmpaddr1, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 2:
					asm volatile("csrw pmpaddr2, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 3:
					asm volatile("csrw pmpaddr3, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 4:
					asm volatile("csrw pmpaddr4, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 5:
					asm volatile("csrw pmpaddr5, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 6:
					asm volatile("csrw pmpaddr6, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 7:
					asm volatile("csrw pmpaddr7, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 8:
					asm volatile("csrw pmpaddr8, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 9:
					asm volatile("csrw pmpaddr9, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 10:
					asm volatile("csrw pmpaddr10, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 11:
					asm volatile("csrw pmpaddr11, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 12:
					asm volatile("csrw pmpaddr12, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 13:
					asm volatile("csrw pmpaddr13, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 14:
					asm volatile("csrw pmpaddr14, %[addr]"
							:: [addr] "r" (address) :);
					break;
				case 15:
					asm volatile("csrw pmpaddr15, %[addr]"
							:: [addr] "r" (address) :);
					break;
				}
			}

			if(old_config.toInt() != config.toInt()) {
				/* Byte lane of this entry inside its pmpcfgN register. The
				 * shifts are done in size_t because 0xFF << 24 does not fit in
				 * a signed int. Masking toInt() to 8 bits keeps a sign-extended
				 * or oversized value from setting bits of neighbouring entries
				 * through csrs. */
				const unsigned int shift = 8 * (region % 4);
				const size_t cfgmask = static_cast<size_t>(0xFF) << shift;
				const size_t pmpcfg = (static_cast<size_t>(config.toInt()) & 0xFF) << shift;

				/* Clear the old byte, then set the new one; in between, this
				 * entry's config byte is briefly zero. */
				switch(region / 4) {
				case 0:
					asm volatile("csrc pmpcfg0, %[mask]"
							:: [mask] "r" (cfgmask) :);

					asm volatile("csrs pmpcfg0, %[cfg]"
							:: [cfg] "r" (pmpcfg) :);
					break;
				case 1:
					asm volatile("csrc pmpcfg1, %[mask]"
							:: [mask] "r" (cfgmask) :);

					asm volatile("csrs pmpcfg1, %[cfg]"
							:: [cfg] "r" (pmpcfg) :);
					break;
				case 2:
					asm volatile("csrc pmpcfg2, %[mask]"
							:: [mask] "r" (cfgmask) :);

					asm volatile("csrs pmpcfg2, %[cfg]"
							:: [cfg] "r" (pmpcfg) :);
					break;
				case 3:
					asm volatile("csrc pmpcfg3, %[mask]"
							:: [mask] "r" (cfgmask) :);

					asm volatile("csrs pmpcfg3, %[cfg]"
							:: [cfg] "r" (pmpcfg) :);
					break;
				}
			}
			return 0;
		}

		/*
		 * Reset every entry in 0 .. PMP_REGIONS-1, in ascending order, to the
		 * baseline configuration "rwxlUO" with a zero address.
		 * Error codes returned by setRegion (e.g. for a locked entry) are ignored.
		 */
		void MMTEXT init() {
			PMPConfig baseline("rwxlUO");
			unsigned int slot = 0;

			while (slot < PMP_REGIONS) {
				setRegion(slot, baseline, 0);
				++slot;
			}
		}


		/*
		 * Set the lock bit of one PMP entry, keeping its address and its other
		 * configuration bits.
		 *
		 * @return 0 on success or if the entry is already locked, otherwise the
		 *         error code of getRegion/setRegion
		 */
		int MMTEXT lock(unsigned int region)
		{
			PMPConfig config;
			size_t address;

			int rc = getRegion(region, &config, &address);
			if(rc) {
				return rc;
			}

			/* Already locked: nothing to do. Tested in boolean context, the
			 * same way setRegion does. Comparing against PMPConfig::LOCKED only
			 * works if that constant happens to equal true (1); otherwise an
			 * already-locked entry fell through to setRegion and returned 4. */
			if(config.isLocked()) {
				return 0;
			}

			config.lock();

			return setRegion(region, config, address);
		}


		/*
		 * Replace the raw pmpaddr value of one entry, keeping its current
		 * configuration.
		 *
		 * @return 0 on success, otherwise the error code of getRegion/setRegion
		 */
		int MMTEXT setAddress(unsigned int region, size_t address)
		{
			PMPConfig current;
			size_t previous_address;

			const int err = getRegion(region, &current, &previous_address);
			return err ? err : setRegion(region, current, address);
		}

		/*
		 * Read the raw pmpaddr value of one entry.
		 *
		 * @return the register value, or 0 if getRegion fails (it leaves the
		 *         output untouched on error)
		 */
		size_t MMTEXT getAddress(unsigned int region)
		{
			size_t value = 0;
			PMPConfig ignored;

			(void) getRegion(region, &ignored, &value);
			return value;
		}


		/*
		 * Change the address-matching mode of one entry via
		 * PMPConfig::setAddressMode, keeping its address and remaining bits.
		 *
		 * @return 0 on success, otherwise the error code of getRegion/setRegion
		 */
		int MMTEXT setAddressMode(unsigned int region, PMPConfig::AddrMode mode)
		{
			PMPConfig updated;
			size_t addr;

			if (const int err = getRegion(region, &updated, &addr))
				return err;

			updated.setAddressMode(mode);
			return setRegion(region, updated, addr);
		}

		/*
		 * Read the address-matching mode of one entry.
		 * If getRegion fails, the mode of a default-constructed PMPConfig is
		 * returned.
		 */
		PMPConfig::AddrMode MMTEXT getAddressMode(unsigned int region)
		{
			PMPConfig current;
			size_t rawAddress = 0;

			(void) getRegion(region, &current, &rawAddress);
			return current.getAddrMode();
		}


		/*
		 * Update the execute (X) flag of one entry via
		 * PMPConfig::setExecutable, keeping its address and other bits.
		 *
		 * @return 0 on success, otherwise the error code of getRegion/setRegion
		 */
		int MMTEXT setExecutable(unsigned int region, bool X)
		{
			PMPConfig cfg;
			size_t addr;

			const int err = getRegion(region, &cfg, &addr);
			if (err != 0)
				return err;

			cfg.setExecutable(X);
			return setRegion(region, cfg, addr);
		}

		/*
		 * Query the execute (X) flag of one entry via PMPConfig::isExecutable.
		 * If getRegion fails, the value of a default-constructed PMPConfig is
		 * returned.
		 */
		bool MMTEXT getExecutable(unsigned int region)
		{
			PMPConfig cfg;
			size_t ignoredAddress = 0;

			(void) getRegion(region, &cfg, &ignoredAddress);
			return cfg.isExecutable();
		}


		/*
		 * Update the write (W) flag of one entry via PMPConfig::setWritable,
		 * keeping its address and other bits.
		 *
		 * @return 0 on success, otherwise the error code of getRegion/setRegion
		 */
		int MMTEXT setWritable(unsigned int region, bool W)
		{
			PMPConfig entry;
			size_t entryAddress;

			if (int failure = getRegion(region, &entry, &entryAddress))
				return failure;

			entry.setWritable(W);
			return setRegion(region, entry, entryAddress);
		}

		/*
		 * Query the write (W) flag of one entry via PMPConfig::isWritable.
		 * If getRegion fails, the value of a default-constructed PMPConfig is
		 * returned.
		 */
		bool MMTEXT getWritable(unsigned int region)
		{
			PMPConfig entry;
			size_t entryAddress = 0;

			(void) getRegion(region, &entry, &entryAddress);
			return entry.isWritable();
		}


		/*
		 * Update the read (R) flag of one entry via PMPConfig::setReadable,
		 * keeping its address and other bits.
		 *
		 * @return 0 on success, otherwise the error code of getRegion/setRegion
		 */
		int MMTEXT setReadable(unsigned int region, bool R)
		{
			PMPConfig current;
			size_t currentAddress;

			const int status = getRegion(region, &current, &currentAddress);
			if (status)
				return status;

			current.setReadable(R);
			return setRegion(region, current, currentAddress);
		}

		/*
		 * Query the read (R) flag of one entry via PMPConfig::isReadable.
		 * If getRegion fails, the value of a default-constructed PMPConfig is
		 * returned.
		 */
		bool MMTEXT getReadable(unsigned int region)
		{
			PMPConfig entry;
			size_t rawAddress = 0;

			(void) getRegion(region, &entry, &rawAddress);
			return entry.isReadable();
		}

		/*
		 * Program PMP entry `slot` from a prepared PMPConfig.
		 *
		 * cfg.getAddress() holds a byte address, or calcNAPOT's encoding shifted
		 * left by 2; 0xffffffff marks a region calcNAPOT could not encode.
		 * The value written to pmpaddr is getAddress() >> 2.
		 *
		 * @return true if setRegion succeeded, false for an invalid address or
		 *         any setRegion error
		 */
		bool MMTEXT apply(int slot, PMPConfig cfg) {
			const bool invalidNapot = (cfg.getAddress() == 0xffffffff);
			if (invalidNapot)
				return false;

			return setRegion(slot, cfg, cfg.getAddress() >> 2) == 0;
		}

	}


	/*
	 * Configure the PMP for the Crypto Core memory layout.
	 *
	 * init() first resets every entry; then the table below is written entry
	 * by entry via apply().
	 *
	 * @return true if every entry was programmed, false as soon as one entry
	 *         has an invalid NAPOT address or cannot be written
	 */
	bool MMTEXT enable() {

		init();

		// addresses of the __mmdata_start / __mmdata_end symbols
		const uint32_t mmdata_end = reinterpret_cast<uint32_t>(&__mmdata_end);
		const uint32_t mmdata_start = reinterpret_cast<uint32_t>(&__mmdata_start);

		// The PMP uses a custom extension which makes shadowing of areas with
		// different privilege levels possible:
		//  - locking bit has different semantics here - it just prevents
		//    changing the register
		//  - priorities: first match counts
		//  - U-rules: opt-in
		//  - M-rules: opt-out
		// NOTE(review): the hardware was described as a smaller PMP with 12
		// registers, but only PMP_REGIONS entries are programmed - confirm.
		// TODO derive the remaining hard-coded addresses from linker symbols
		PMPConfig cfg[PMP_REGIONS]={
			// 0x00000000 to mmdata_start-1: RWX in debug builds, RwX otherwise
#ifdef DEBUG
			PMPConfig("RWXlMT", mmdata_start),	// W for breakpoints
#else
			PMPConfig("RwXlMT", mmdata_start),	// only works without debugging!
#endif
			// RWx (mmdata_start to mmdata_end-1)
			PMPConfig("RWxlMT", mmdata_end),

			// RwX (mmdata_end to 0x0001ffff)
			PMPConfig("RwXlUT", 0x00020000),

			// RWx (0x00020000 to 0x8001efff - 0x00020000 to 0x7fffffff is not decoded)
			PMPConfig("RWxlUT", 0x8001f000),

			// User-Peripherals (1 MiB NAPOT region at 0xf1000000)
			PMPConfig("RWxlUN", calcNAPOT(0xf1000000, 0x100000)),

			// protect from overflowing the machine-mode stack
			PMPConfig("rwxlM4", 0x8001f004),

			// protect from code execution outside of ROM
			PMPAddress(			0x00020000),
			PMPConfig("RWxlMT",	0xf8000000),
		//--------------------------------------------------------------
		};

		for (int i = 0; i < PMP_REGIONS; i++) {
			// apply() rejects entries whose NAPOT address was marked invalid
			if (!apply(i, cfg[i]))
				return false;
		}

		return true;
	}

}
